"""Tests for eval_lib.testing.fake_cloud_client."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
from io import BytesIO
import unittest
from eval_lib.tests import fake_cloud_client
from six import assertCountEqual
from six import b as six_b


class FakeStorageClientTest(unittest.TestCase):
    """Checks blob listing and blob retrieval of FakeStorageClient."""

    def test_list_blobs(self):
        """list_blobs returns all blobs, or only those selected by its argument."""
        csv_blob = "dataset/dev_dataset.csv"
        image_blobs = ["dataset/dev/img1.png", "dataset/dev/img2.png"]
        blob_names = ["some_blob", csv_blob] + image_blobs
        storage = fake_cloud_client.FakeStorageClient(blob_names)
        # Without an argument every blob is expected back.
        assertCountEqual(self, blob_names, storage.list_blobs())
        # "dataset/dev" is expected to select the CSV file as well as the images.
        assertCountEqual(
            self, [csv_blob] + image_blobs, storage.list_blobs("dataset/dev")
        )
        # With the trailing slash only the images are expected.
        assertCountEqual(self, image_blobs, storage.list_blobs("dataset/dev/"))

    def test_get_blob(self):
        """get_blob gives None for an unknown name, else a downloadable blob."""
        contents = {"some_blob": "some_content", "blob2": "another_content"}
        storage = fake_cloud_client.FakeStorageClient(contents)
        missing = storage.get_blob("blob3")
        self.assertIsNone(missing)
        sink = BytesIO()
        blob = storage.get_blob("some_blob")
        blob.download_to_file(sink)
        # The downloaded payload is compared as bytes.
        self.assertEqual(six_b("some_content"), sink.getvalue())


class FakeDatastoreKeyTest(unittest.TestCase):
    """Checks flat_path and the == / != operators of FakeDatastoreKey."""

    def test_flat_path(self):
        """A child key's flat_path is its parent's path followed by its own pair."""
        parent_key = fake_cloud_client.FakeDatastoreKey("abc", "1")
        self.assertTupleEqual(("abc", "1"), parent_key.flat_path)
        child_key = fake_cloud_client.FakeDatastoreKey(
            "def", "xyz", parent=parent_key
        )
        self.assertTupleEqual(("abc", "1", "def", "xyz"), child_key.flat_path)

    def test_equality(self):
        """Keys compare equal by content and never equal a plain tuple."""
        root_a = fake_cloud_client.FakeDatastoreKey("abc", "1")
        root_b = fake_cloud_client.FakeDatastoreKey("abc", "1")
        child_a = fake_cloud_client.FakeDatastoreKey("def", "xyz", parent=root_a)
        child_b = fake_cloud_client.FakeDatastoreKey("def", "xyz", parent=root_a)
        # The operators are spelled out (rather than assertEqual) so that both
        # == and != are exercised explicitly.
        equal_pairs = (
            (root_a, root_a),  # a key and itself
            (root_a, root_b),  # two keys built from the same arguments
            (child_a, child_b),  # two child keys sharing a parent
        )
        for left, right in equal_pairs:
            self.assertTrue(left == right)
            self.assertFalse(left != right)
        # Keys with different paths differ.
        self.assertFalse(root_a == child_a)
        self.assertTrue(root_a != child_a)
        # A key is not equal to the tuple holding its own flat path.
        self.assertTrue(root_a != root_a.flat_path)
        self.assertFalse(root_a == root_a.flat_path)


class FakeDatastoreEntityTest(unittest.TestCase):
    """Tests for FakeDatastoreEntity: key access, equality and copying."""

    def test_key(self):
        """An entity made from a flat path exposes the corresponding key."""
        entity = fake_cloud_client.make_entity(("abc", "1"))
        self.assertEqual(entity.key, fake_cloud_client.FakeDatastoreKey("abc", "1"))

    def test_equality_keys(self):
        """Entities without properties are equal exactly when their keys match."""
        entity1a = fake_cloud_client.make_entity(("abc", "1"))
        entity1b = fake_cloud_client.make_entity(("abc", "1"))
        entity2 = fake_cloud_client.make_entity(("abc", "2"))
        self.assertFalse(entity1a == entity2)
        self.assertTrue(entity1a != entity2)
        self.assertTrue(entity1a == entity1b)
        # Compare two distinct objects here; comparing entity1b with itself
        # would not exercise != on separate-but-equal entities.
        self.assertFalse(entity1a != entity1b)

    def test_equality_dict(self):
        """Entities sharing a key are equal only when their properties match."""
        entity1 = fake_cloud_client.make_entity(("abc", "1"))
        entity1["k1"] = "v1"
        entity2 = fake_cloud_client.make_entity(("abc", "1"))
        entity2["k1"] = "v2"
        # entity3 repeats entity1's property and adds one more, so the
        # comparisons below also cover "same value plus an extra property".
        entity3 = fake_cloud_client.make_entity(("abc", "1"))
        entity3["k1"] = "v1"
        entity3["k2"] = "v2"
        # compare to self
        self.assertTrue(entity1 == entity1)
        self.assertFalse(entity1 != entity1)
        self.assertTrue(entity2 == entity2)
        self.assertFalse(entity2 != entity2)
        self.assertTrue(entity3 == entity3)
        self.assertFalse(entity3 != entity3)
        # compare to others
        self.assertFalse(entity1 == entity2)
        self.assertTrue(entity1 != entity2)
        self.assertFalse(entity1 == entity3)
        self.assertTrue(entity1 != entity3)
        self.assertFalse(entity2 == entity3)
        self.assertTrue(entity2 != entity3)

    def test_copy(self):
        """copy.copy keeps type and key and shares nested values with the source."""
        entity1 = fake_cloud_client.make_entity(("abc", "1"))
        entity1["k1"] = ["v1"]
        self.assertEqual(entity1.key, fake_cloud_client.FakeDatastoreKey("abc", "1"))
        self.assertEqual(dict(entity1), {"k1": ["v1"]})
        entity2 = copy.copy(entity1)
        # Mutating the nested list is visible through both entities, while a
        # newly assigned property exists only on the copy.
        entity2["k1"].append("v2")
        entity2["k3"] = "v3"
        self.assertIsInstance(entity2, fake_cloud_client.FakeDatastoreEntity)
        self.assertEqual(entity1.key, fake_cloud_client.FakeDatastoreKey("abc", "1"))
        self.assertEqual(dict(entity1), {"k1": ["v1", "v2"]})
        self.assertEqual(entity2.key, fake_cloud_client.FakeDatastoreKey("abc", "1"))
        self.assertEqual(dict(entity2), {"k1": ["v1", "v2"], "k3": "v3"})

    def test_deep_copy(self):
        """copy.deepcopy keeps type and key and duplicates nested values."""
        entity1 = fake_cloud_client.make_entity(("abc", "1"))
        entity1["k1"] = ["v1"]
        self.assertEqual(entity1.key, fake_cloud_client.FakeDatastoreKey("abc", "1"))
        self.assertEqual(dict(entity1), {"k1": ["v1"]})
        entity2 = copy.deepcopy(entity1)
        # Neither the list mutation nor the new property may leak back into
        # the original entity.
        entity2["k1"].append("v2")
        entity2["k3"] = "v3"
        self.assertIsInstance(entity2, fake_cloud_client.FakeDatastoreEntity)
        self.assertEqual(entity1.key, fake_cloud_client.FakeDatastoreKey("abc", "1"))
        self.assertEqual(dict(entity1), {"k1": ["v1"]})
        self.assertEqual(entity2.key, fake_cloud_client.FakeDatastoreKey("abc", "1"))
        self.assertEqual(dict(entity2), {"k1": ["v1", "v2"], "k3": "v3"})


class FakeDatastoreClientTest(unittest.TestCase):
    """Covers key/entity creation, puts, batches and queries of the fake client."""

    def setUp(self):
        """Build an empty client and two unsaved entities with parent/child keys."""
        datastore = fake_cloud_client.FakeDatastoreClient()
        self._client = datastore
        self._key1 = datastore.key("abc", "def")
        self._key2 = datastore.key("qwe", "rty", parent=self._key1)
        parent_entity = datastore.entity(self._key1)
        parent_entity["k1"] = "v1"
        self._entity1 = parent_entity
        child_entity = datastore.entity(self._key2)
        child_entity["k2"] = "v2"
        child_entity["k3"] = "v3"
        self._entity2 = child_entity

    @staticmethod
    def _entity_with(path, *pairs):
        """Return make_entity(path) after assigning each (name, value) in order."""
        entity = fake_cloud_client.make_entity(path)
        for name, value in pairs:
            entity[name] = value
        return entity

    def test_make_key(self):
        """Keys created by the client report the expected flat paths."""
        expectations = (
            (("abc", "def"), self._key1),
            (("abc", "def", "qwe", "rty"), self._key2),
        )
        for flat_path, key in expectations:
            self.assertTupleEqual(flat_path, key.flat_path)

    def test_make_entity(self):
        """An entity created by the client carries the key it was built from."""
        entity_key = self._entity1.key
        self.assertTupleEqual(("abc", "def"), entity_key.flat_path)

    def test_put_entity(self):
        """put stores each entity under its key in client.entities."""
        expected = {}
        self.assertDictEqual(expected, self._client.entities)
        self._client.put(self._entity1)
        expected[self._key1] = self._entity1
        self.assertDictEqual(expected, self._client.entities)
        self._client.put(self._entity2)
        expected[self._key2] = self._entity2
        self.assertDictEqual(expected, self._client.entities)

    def test_get_entity(self):
        """get returns what was previously put under the same key."""
        stored = ((self._key1, self._entity1), (self._key2, self._entity2))
        for _, entity in stored:
            self._client.put(entity)
        for key, entity in stored:
            self.assertEqual(entity, self._client.get(key))

    def test_write_batch(self):
        """Entities put through a non-transactional batch end up in the store."""
        with self._client.no_transact_batch() as writer:
            writer.put(self._entity1)
            writer.put(self._entity2)
        saved = self._client.entities
        assertCountEqual(self, [self._key1, self._key2], self._client.entities.keys())
        self.assertEqual(self._key1, saved[self._key1].key)
        self.assertDictEqual({"k1": "v1"}, dict(saved[self._key1]))
        self.assertEqual(self._key2, saved[self._key2].key)
        self.assertDictEqual({"k2": "v2", "k3": "v3"}, dict(saved[self._key2]))

    def test_overwrite_values(self):
        """A second batch put under the same key replaces/extends stored values."""
        datastore = fake_cloud_client.FakeDatastoreClient()
        shared_key = datastore.key("abc", "def")
        original = datastore.entity(shared_key)
        original["k1"] = "v1"
        replacement = datastore.entity(shared_key)
        replacement["k1"] = "v2"
        replacement["k2"] = "v3"
        # First write: only the original values are present.
        with datastore.no_transact_batch() as writer:
            writer.put(original)
        assertCountEqual(self, [shared_key], datastore.entities.keys())
        self.assertEqual(shared_key, datastore.entities[shared_key].key)
        self.assertDictEqual({"k1": "v1"}, dict(datastore.entities[shared_key]))
        # Second write under the same key: still one entity, with new values.
        with datastore.no_transact_batch() as writer:
            writer.put(replacement)
        assertCountEqual(self, [shared_key], datastore.entities.keys())
        self.assertEqual(shared_key, datastore.entities[shared_key].key)
        self.assertDictEqual(
            {"k1": "v2", "k2": "v3"}, dict(datastore.entities[shared_key])
        )

    def test_query_fetch_all(self):
        """query_fetch without arguments returns every stored entity."""
        first = self._entity_with(("abc", "1"), ("k1", "v1"))
        second = self._entity_with(("abc", "1", "def", "2"), ("k2", "v2"))
        datastore = fake_cloud_client.FakeDatastoreClient([first, second])
        assertCountEqual(self, [first, second], datastore.query_fetch())

    def test_query_fetch_kind_filter(self):
        """The kind filter is expected to select on the entity's own kind only."""
        top_level = self._entity_with(("abc", "1"), ("k1", "v1"))
        nested = self._entity_with(("abc", "1", "def", "2"), ("k2", "v2"))
        datastore = fake_cloud_client.FakeDatastoreClient([top_level, nested])
        # "abc" appears in the nested entity's path too, yet only the
        # top-level entity is expected for kind="abc".
        for kind, expected in (("abc", [top_level]), ("def", [nested])):
            assertCountEqual(self, expected, datastore.query_fetch(kind=kind))

    def test_query_fetch_ancestor_filter(self):
        """The ancestor filter selects entities located under the given key."""
        under_abc = self._entity_with(("abc", "1", "def", "2"), ("k1", "v1"))
        under_xyz = self._entity_with(("xyz", "3", "qwe", "4"), ("k2", "v2"))
        datastore = fake_cloud_client.FakeDatastoreClient([under_abc, under_xyz])
        assertCountEqual(
            self,
            [under_abc],
            datastore.query_fetch(ancestor=datastore.key("abc", "1")),
        )
        assertCountEqual(
            self,
            [under_xyz],
            datastore.query_fetch(ancestor=datastore.key("xyz", "3")),
        )

    def test_query_fetch_ancestor_and_kind_filter(self):
        """kind and ancestor filters given together must both be satisfied."""
        right_kind_and_ancestor = self._entity_with(
            ("abc", "1", "def", "2"), ("k1", "v1")
        )
        wrong_kind = self._entity_with(("abc", "1", "xyz", "3"), ("k2", "v2"))
        wrong_ancestor = self._entity_with(("def", "4"), ("k2", "v2"))
        datastore = fake_cloud_client.FakeDatastoreClient(
            [right_kind_and_ancestor, wrong_kind, wrong_ancestor]
        )
        matches = datastore.query_fetch(
            kind="def", ancestor=datastore.key("abc", "1")
        )
        assertCountEqual(self, [right_kind_and_ancestor], matches)

    def test_query_fetch_data_filter(self):
        """Property filters support =, >, >=, <, <= and can be combined."""
        only_k1 = self._entity_with(("abc", "1"), ("k1", "v1"))
        k1_and_k2 = self._entity_with(("abc", "2"), ("k1", "v2"), ("k2", "v2"))
        only_k2 = self._entity_with(("abc", "3"), ("k2", "v3"))
        datastore = fake_cloud_client.FakeDatastoreClient(
            [only_k1, k1_and_k2, only_k2]
        )
        # Each row: (expected entities, filters passed to query_fetch).
        scenarios = (
            ([only_k1], [("k1", "=", "v1")]),
            ([k1_and_k2], [("k1", ">", "v1")]),
            ([only_k1, k1_and_k2], [("k1", ">=", "v1")]),
            ([k1_and_k2], [("k2", "<", "v3")]),
            ([k1_and_k2, only_k2], [("k2", "<=", "v3")]),
            ([k1_and_k2], [("k1", ">=", "v1"), ("k2", "<=", "v3")]),
        )
        for expected, filters in scenarios:
            assertCountEqual(self, expected, datastore.query_fetch(filters=filters))


class FakeDatastoreClientTransactionTest(unittest.TestCase):
    """Exercises FakeDatastoreClient transactions with and without conflicts."""

    def setUp(self):
        """Store three entities in a fresh client and verify the initial state."""
        datastore = fake_cloud_client.FakeDatastoreClient()
        self._client = datastore
        self._key1 = datastore.key("abc", "def")
        self._key2 = datastore.key("qwe", "rty", parent=self._key1)
        self._key3 = datastore.key("123", "456")
        self._entity1 = datastore.entity(self._key1)
        self._entity1["k1"] = "v1"
        self._entity2 = datastore.entity(self._key2)
        self._entity2["k2"] = "v2"
        self._entity2["k3"] = "v3"
        self._entity3 = datastore.entity(self._key3)
        self._entity3["k4"] = "v4"
        self._entity3["k5"] = "v5"
        self._entity3["k6"] = "v6"
        for entity in (self._entity1, self._entity2, self._entity3):
            datastore.put(entity)
        # Sanity-check the starting content every test relies on.
        self._check_keys(self._key1, self._key2, self._key3)
        self._check_props(self._key1, {"k1": "v1"})
        self._check_props(self._key2, {"k2": "v2", "k3": "v3"})
        self._check_props(self._key3, {"k4": "v4", "k5": "v5", "k6": "v6"})

    def _check_keys(self, *expected_keys):
        """Assert that the datastore holds exactly the given keys."""
        assertCountEqual(self, list(expected_keys), self._client.entities.keys())

    def _check_props(self, key, expected_props):
        """Assert the properties currently stored under key."""
        self.assertDictEqual(expected_props, dict(self._client.entities[key]))

    def _check_snapshot_read(self, transaction):
        """Read entity 3 inside transaction and expect its pre-transaction values."""
        seen = self._client.get(self._key3, transaction=transaction)
        self.assertDictEqual({"k4": "v4", "k5": "v5", "k6": "v6"}, dict(seen))

    def _draft(self, key, name, value):
        """Create an unsaved entity for key holding the single property name=value."""
        draft = self._client.entity(key)
        draft[name] = value
        return draft

    def test_transaction_write_only_no_concurrent(self):
        """Two writes inside one transaction are both applied on exit."""
        new_key = self._client.key("zxc", "vbn")
        new_entity = self._draft(new_key, "k7", "v7")
        entity3_change = self._draft(self._key3, "k4", "upd_v4")
        with self._client.transaction() as txn:
            txn.put(new_entity)
            txn.put(entity3_change)
        # After the transaction: new entity present, entity 3 updated.
        self._check_keys(self._key1, self._key2, self._key3, new_key)
        self._check_props(self._key1, {"k1": "v1"})
        self._check_props(self._key2, {"k2": "v2", "k3": "v3"})
        self._check_props(self._key3, {"k4": "upd_v4", "k5": "v5", "k6": "v6"})
        self._check_props(new_key, {"k7": "v7"})

    def test_transaction_read_write_no_concurrent(self):
        """Reads inside a transaction see the snapshot from before its writes."""
        new_key = self._client.key("zxc", "vbn")
        new_entity = self._draft(new_key, "k7", "v7")
        entity3_change = self._draft(self._key3, "k4", "upd_v4")
        with self._client.transaction() as txn:
            # Before any write the stored values are visible.
            self._check_snapshot_read(txn)
            txn.put(entity3_change)
            txn.put(new_entity)
            # Pending writes must not be visible to reads in the transaction.
            self._check_snapshot_read(txn)
        self._check_keys(self._key1, self._key2, self._key3, new_key)
        self._check_props(self._key1, {"k1": "v1"})
        self._check_props(self._key2, {"k2": "v2", "k3": "v3"})
        self._check_props(self._key3, {"k4": "upd_v4", "k5": "v5", "k6": "v6"})
        self._check_props(new_key, {"k7": "v7"})

    def test_transaction_read_write_concurrent_not_intersecting(self):
        """An outside write to an entity the transaction ignores does not conflict."""
        new_key = self._client.key("zxc", "vbn")
        new_entity = self._draft(new_key, "k7", "v7")
        entity3_change = self._draft(self._key3, "k4", "upd_v4")
        entity1_change = self._draft(self._key1, "k1", "upd_v1")
        with self._client.transaction() as txn:
            self._check_snapshot_read(txn)
            txn.put(entity3_change)
            # Outside write to entity 1, which the transaction never touches.
            self._client.put(entity1_change)
            txn.put(new_entity)
            self._check_snapshot_read(txn)
        # Both the outside write and the transaction's writes are applied.
        self._check_keys(self._key1, self._key2, self._key3, new_key)
        self._check_props(self._key1, {"k1": "upd_v1"})
        self._check_props(self._key2, {"k2": "v2", "k3": "v3"})
        self._check_props(self._key3, {"k4": "upd_v4", "k5": "v5", "k6": "v6"})
        self._check_props(new_key, {"k7": "v7"})

    def test_transaction_write_concurrent(self):
        """An outside write to an entity the transaction writes aborts the commit."""
        new_key = self._client.key("zxc", "vbn")
        new_entity = self._draft(new_key, "k7", "v7")
        entity3_change = self._draft(self._key3, "k4", "upd_v4")
        entity3_outside_change = self._draft(self._key3, "k4", "another_v4")
        body_completed = False
        with self.assertRaises(Exception):
            with self._client.transaction() as txn:
                txn.put(entity3_change)
                # Outside write to entity 3, which the transaction also writes.
                self._client.put(entity3_outside_change)
                txn.put(new_entity)
                body_completed = True
        # The exception must come from leaving the transaction, not its body.
        self.assertTrue(body_completed)
        # Only the outside write is kept; nothing from the transaction is.
        self._check_keys(self._key1, self._key2, self._key3)
        self._check_props(self._key1, {"k1": "v1"})
        self._check_props(self._key2, {"k2": "v2", "k3": "v3"})
        self._check_props(self._key3, {"k4": "another_v4", "k5": "v5", "k6": "v6"})

    def test_transaction_read_concurrent(self):
        """An outside write to an entity the transaction read aborts the commit."""
        new_key = self._client.key("zxc", "vbn")
        new_entity = self._draft(new_key, "k7", "v7")
        entity3_outside_change = self._draft(self._key3, "k4", "another_v4")
        body_completed = False
        with self.assertRaises(Exception):
            with self._client.transaction() as txn:
                txn.put(new_entity)
                # Entity 3 is only read by the transaction...
                self._check_snapshot_read(txn)
                # ...and then modified outside of it.
                self._client.put(entity3_outside_change)
                body_completed = True
        # The exception must come from leaving the transaction, not its body.
        self.assertTrue(body_completed)
        # Only the outside write is kept; the transaction's new entity is not.
        self._check_keys(self._key1, self._key2, self._key3)
        self._check_props(self._key1, {"k1": "v1"})
        self._check_props(self._key2, {"k2": "v2", "k3": "v3"})
        self._check_props(self._key3, {"k4": "another_v4", "k5": "v5", "k6": "v6"})


# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
